{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d72fae4e-f7de-42b7-91ee-bdd0a57ae46c",
   "metadata": {},
   "source": [
    "# Prompt Generator\n",
    "\n",
    "In this example we will create a chat bot that helps a user generate a prompt.\n",
    "It will first collect requirements from the user, and then will generate the prompt (and refine it based on user input).\n",
    "These are split into two separate states, and the LLM decides when to transition between them.\n",
    "\n",
    "A graphical representation of the system can be found below.\n",
    "\n",
    "![](imgs/prompt-generator.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d78b593-ba26-4c90-b2e2-83119e47679f",
   "metadata": {},
   "source": [
    "## Gather information\n",
    "\n",
    "First, let's define the part of the graph that will gather user requirements. This will be an LLM call with a specific system message. It will have access to a tool that it can call when it is ready to generate the prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "53216ab5-2cd3-48a4-8778-41ba10f72519",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.messages import SystemMessage\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
    "from typing import List"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5f795b78-004d-40ca-95d6-069f67e4f9c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Your job is to get information from a user about what type of prompt template they want to create.\n",
    "\n",
    "You should get the following information from them:\n",
    "\n",
    "- What the objective of the prompt is\n",
    "- What variables will be passed into the prompt template\n",
    "- Any constraints for what the output should NOT do\n",
    "- Any requirements that the output MUST adhere to\n",
    "\n",
    "If you are not able to discerne this info, ask them to clarify! Do not attempt to wildly guess.\n",
    "\n",
    "After you are able to discerne all the information, call the relevant tool\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "153a7c89-ec75-430e-96ac-6294d4e6ea2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = ChatOpenAI(temperature=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8325ed0e-177a-4129-b4a4-cd54d527c634",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_messages_info(messages):\n",
    "    return [SystemMessage(content=template)] + messages\n",
    "\n",
    "\n",
    "class PromptInstructions(BaseModel):\n",
    "    \"\"\"Instructions on how to prompt the LLM.\"\"\"\n",
    "    objective: str\n",
    "    variables: List[str]\n",
    "    constraints: List[str]\n",
    "    requirements: List[str]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "95bb557a-47f2-46fd-adf7-9cf0e8ffe51c",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_with_tool = llm.bind_tools([PromptInstructions])\n",
    "\n",
    "chain = get_messages_info | llm_with_tool"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb40630f-83c7-4283-a6dd-04231805a7ed",
   "metadata": {},
   "source": [
    "## Generate Prompt\n",
    "\n",
    "We now set up the state that will generate the prompt.\n",
    "This will require a separate system message, as well as a function to filter out all message PRIOR to the tool invocation (as that is when the previous state decided it was time to generate the prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5b6736fb-21d7-44d3-a3fe-460f15945654",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper function for determining if tool was called\n",
    "def _is_tool_call(msg):\n",
    "    return hasattr(msg, \"additional_kwargs\") and 'tool_calls' in msg.additional_kwargs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ca9a0234-bbeb-4bff-8276-8dde499c3390",
   "metadata": {},
   "outputs": [],
   "source": [
    "# New system prompt\n",
    "prompt_system = \"\"\"Based on the following requirements, write a good prompt template:\n",
    "\n",
    "{reqs}\"\"\"\n",
    "\n",
    "# Function to get the messages for the prompt\n",
    "# Will only get messages AFTER the tool call\n",
    "def get_prompt_messages(messages):\n",
    "    tool_call = None\n",
    "    other_msgs = []\n",
    "    for m in messages:\n",
    "        if _is_tool_call(m):\n",
    "            tool_call = m.additional_kwargs['tool_calls'][0]['function']['arguments']\n",
    "        elif tool_call is not None:\n",
    "            other_msgs.append(m)\n",
    "    return [SystemMessage(content=prompt_system.format(reqs=tool_call))] + other_msgs\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3293d2d3-e060-4c98-8065-3696b049b5fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_gen_chain = get_prompt_messages | llm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dbabda8-34f0-4eef-bce2-ad3ff505366b",
   "metadata": {},
   "source": [
    "## Define the state logic\n",
    "\n",
    "This is the logic for what state the chatbot is in.\n",
    "If the last message is a tool call, then we are in the state where the \"prompt creator\" (`prompt`) should respond.\n",
    "Otherwise, if the last message is not a HumanMessage, then we know the human should respond next and so we are in the `END` state.\n",
    "If the last message is a HumanMessage, then if there was a tool call previously we are in the `prompt` state.\n",
    "Otherwise, we are in the \"info gathering\" (`info`) state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "74f29e15-20e2-420c-a450-84e929f16e4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_state(messages):\n",
    "    if _is_tool_call(messages[-1]):\n",
    "        return \"prompt\"\n",
    "    elif not isinstance(messages[-1], HumanMessage):\n",
    "        return END\n",
    "    for m in messages:\n",
    "        if _is_tool_call(m):\n",
    "            return \"prompt\"\n",
    "    return \"info\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b76bea78-07a5-418f-9b7c-71c376d4b6f7",
   "metadata": {},
   "source": [
    "## Create the graph\n",
    "\n",
    "We can now the create the graph.\n",
    "We will use a SqliteSaver to persist conversation history."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "59d9d6b4-dce4-43cc-9a1a-61a7912ed5b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.graph import MessageGraph, END\n",
    "from langgraph.checkpoint.sqlite import SqliteSaver\n",
    "\n",
    "memory = SqliteSaver.from_conn_string(\":memory:\")\n",
    "\n",
    "nodes = {k:k for k in ['info', 'prompt', END]}\n",
    "workflow = MessageGraph()\n",
    "workflow.add_node(\"info\", chain)\n",
    "workflow.add_node(\"prompt\", prompt_gen_chain)\n",
    "workflow.add_conditional_edges(\"info\", get_state, nodes)\n",
    "workflow.add_conditional_edges(\"prompt\", get_state, nodes)\n",
    "workflow.set_entry_point(\"info\")\n",
    "graph = workflow.compile(checkpointer=memory)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afcf523c-265d-45cf-a981-fc50c50c1738",
   "metadata": {},
   "source": [
    "## Use the graph\n",
    "\n",
    "We can now use the created chatbot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "25793988-45a2-4e65-b33c-64e72aadb10e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  hi!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output from node 'info':\n",
      "---\n",
      "content='Hello! How can I assist you today?'\n",
      "\n",
      "---\n",
      "\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  build me a prompt for extraction\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output from node 'info':\n",
      "---\n",
      "content='Sure! I can help you with that. Could you please provide me with more details about the prompt you want to create? Specifically, I need to know the objective of the prompt, the variables that will be passed into the prompt template, any constraints for what the output should not do, and any requirements that the output must adhere to.'\n",
      "\n",
      "---\n",
      "\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  i want to do extraction over a page\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output from node 'info':\n",
      "---\n",
      "content='Great! Could you please provide me with more details about the objective of the extraction? What specific information are you looking to extract from the page?'\n",
      "\n",
      "---\n",
      "\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  i want the user to specify that at run time\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output from node 'info':\n",
      "---\n",
      "content=\"Understood. So the objective of the prompt is to allow the user to specify the information they want to extract from a page at runtime. \\n\\nNow, let's move on to the variables. Are there any specific variables that you would like to pass into the prompt template? For example, the URL of the page or any other parameters that might be relevant for the extraction process.\"\n",
      "\n",
      "---\n",
      "\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  the schema to extract, and the text to extract it from\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output from node 'info':\n",
      "---\n",
      "content='Got it. So the variables that will be passed into the prompt template are the schema to extract and the text to extract it from.\\n\\nNext, are there any constraints for what the output should not do? For example, should the output not include any sensitive information or should it not exceed a certain length?'\n",
      "\n",
      "---\n",
      "\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  it must be in json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output from node 'info':\n",
      "---\n",
      "content='Understood. So a requirement for the output is that it must be in JSON format.\\n\\nLastly, are there any specific requirements that the output must adhere to? For example, should the output follow a specific structure or include certain fields?'\n",
      "\n",
      "---\n",
      "\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  must be json, must include the same fields as the schema specified\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output from node 'info':\n",
      "---\n",
      "content='Got it. So the requirements for the output are that it must be in JSON format and it must include the same fields as the schema specified.\\n\\nBased on the information you provided, I will now generate the prompt template for extraction. Please give me a moment.\\n\\n' additional_kwargs={'tool_calls': [{'id': 'call_6roy9dQoIrQZsHffR9kjAr0e', 'function': {'arguments': '{\\n  \"objective\": \"Extract specific information from a page\",\\n  \"variables\": [\"schema\", \"text\"],\\n  \"constraints\": [\"Output should not include sensitive information\", \"Output should not exceed a certain length\"],\\n  \"requirements\": [\"Output must be in JSON format\", \"Output must include the same fields as the specified schema\"]\\n}', 'name': 'PromptInstructions'}, 'type': 'function'}]}\n",
      "\n",
      "---\n",
      "\n",
      "Output from node 'prompt':\n",
      "---\n",
      "content='Extract specific information from a page and output the result in JSON format. The input page should contain the following fields: {{schema}}. The extracted information should be stored in the variable {{text}}. Ensure that the output does not include any sensitive information and does not exceed a certain length. Additionally, the output should include the same fields as the specified schema.'\n",
      "\n",
      "---\n",
      "\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  q\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AI: Byebye\n"
     ]
    }
   ],
   "source": [
    "import uuid\n",
    "from langchain_core.messages import HumanMessage\n",
    "\n",
    "config = {\"configurable\": {\"thread_id\": str(uuid.uuid4())}}\n",
    "while True:\n",
    "    user = input('User (q/Q to quit): ')\n",
    "    if user in {'q', 'Q'}:\n",
    "        print('AI: Byebye')\n",
    "        break\n",
    "    for output in graph.stream([HumanMessage(content=user)], config=config):\n",
    "        if \"__end__\" in output:\n",
    "            continue\n",
    "        # stream() yields dictionaries with output keyed by node name\n",
    "        for key, value in output.items():\n",
    "            print(f\"Output from node '{key}':\")\n",
    "            print(\"---\")\n",
    "            print(value)\n",
    "        print(\"\\n---\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
